import argparse
import mmengine
import os
from tqdm import tqdm
import random
from vllm import LLM, SamplingParams
import teval.evaluators as evaluator_factory
from teval.utils.meta_template import meta_template_dict
from typing import List, Dict, Optional, Union
import os
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
import mmengine
import os
from format_json import *

# Point Triton's CUDA tool lookups (ptxas, cuobjdump, nvdisasm) at one fixed bin
# directory. NOTE(review): the path is specific to one conda env; adjust per machine.
os.environ["CUDALIB_PATH"] = "/miniforge3/envs/vllm/bin"
for _env_var, _tool_name in (
    ("TRITON_PTXAS_PATH", "ptxas"),
    ("TRITON_CUOBJDUMP_PATH", "cuobjdump"),
    ("TRITON_NVDISASM_PATH", "nvdisasm"),
):
    os.environ[_env_var] = os.path.join(os.environ["CUDALIB_PATH"], _tool_name)
del _env_var, _tool_name
# Turn off HuggingFace tokenizers' internal parallelism (presumably to avoid
# fork-related warnings/deadlocks under multi-process launchers).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'



class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Renders a dialog (a list of message dicts with ``'role'`` and ``'content'``
    keys, optionally ``'name'``) into one prompt string, wrapping each message
    with the ``begin``/``end`` markers configured for its role.

    Args:
        meta_template (list of dict, optional): The meta template for the model.
            Each entry needs a unique ``'role'`` and may define ``'begin'``
            (a string, or a dict with ``'with_name'``/``'without_name'``/``'name'``),
            ``'end'`` (a string) and ``'generate'`` (bool; when true and that
            role's message is last, the message is left open for the model).
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Maps role name to its config. Always initialised so the helper
        # methods never hit an AttributeError, even without a template.
        self.roles: Dict[str, dict] = dict()
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, 'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: List[Dict]) -> str:
        """Parse a prompt template and wrap it with the meta template if applicable.

        Args:
            dialog (List[Dict]): A list representing the dialog sequence.
                Non-dict items are ignored.

        Returns:
            str: The final formatted string. Without a meta template, the
            non-empty message contents are joined with newlines.
        """
        assert isinstance(dialog, list)
        if self.meta_template:
            last_index = len(dialog) - 1
            return ''.join(
                self._prompt2str(item, index == last_index)
                for index, item in enumerate(dialog)
                if isinstance(item, dict)
            )
        # In case the model does not have any meta template: plain newline join
        # of the non-empty contents (no stray separator before the first one).
        return '\n'.join(
            item['content']
            for item in dialog
            if isinstance(item, dict) and item.get('content', '')
        )

    def _format_begin(self, role_cfg: Dict, message: Dict) -> str:
        """Return the 'begin' marker for ``message`` under ``role_cfg``.

        A string ``begin`` is used as-is. A dict ``begin`` uses
        ``'with_name'`` (formatted with ``{name}``, optionally remapped through
        ``'name'``) when the message has a name, else ``'without_name'``.
        A missing or unsupported ``begin`` yields an empty string.
        """
        begin_cfg = role_cfg.get('begin', '')
        if isinstance(begin_cfg, str):
            # A plain-string begin has no name slot, so any message name is ignored.
            return begin_cfg
        if not isinstance(begin_cfg, dict):
            return ''
        name = message.get('name', None)
        if name is None:
            return begin_cfg.get('without_name', '')
        display_name = begin_cfg.get('name', {}).get(name, name)
        return begin_cfg.get('with_name', '').format(name=display_name)

    def _prompt2str(self, prompt: Dict, last: bool = False) -> str:
        """Convert a single prompt dictionary to a formatted string.

        Roles missing from the template are rendered with empty begin/end markers.
        """
        role_cfg = self.roles.get(prompt['role'], {})
        res = self._format_begin(role_cfg, prompt)
        if last and role_cfg.get('generate', False):
            # Final message of a generating role: leave it open (no 'end') so the
            # model continues it.
            return res + prompt.get('content', '')
        res += prompt.get('content', '') + role_cfg.get('end', '')
        if last and prompt['role'] != 'assistant':
            # Open an assistant turn so the model's continuation is its reply.
            res += self._format_begin(self.roles.get('assistant', {}), {})
        return res


def generate_prompt(model_name: str, dialog: List[Dict], meta_template_dict: Dict[str, List[Dict]]) -> str:
    """Render ``dialog`` with the meta template registered for ``model_name``.

    Args:
        model_name (str): Key into ``meta_template_dict`` (e.g. 'internlm', 'llama2').
        dialog (List[Dict]): The list of dialog messages.
        meta_template_dict (Dict[str, List[Dict]]): Template configurations keyed by model name.

    Returns:
        str: The generated prompt string.

    Raises:
        ValueError: If no template is registered for ``model_name``.
    """
    if model_name in meta_template_dict:
        parser = LMTemplateParser(meta_template_dict[model_name])
        return parser(dialog)
    raise ValueError(f"Unsupported model name: {model_name}")


def parse_args():
    """Declare the command-line options and return the parsed namespace."""
    # (flag, add_argument keyword arguments) in declaration order; the order
    # determines how options are listed in --help.
    options = (
        ('--dataset_path', dict(type=str, default='/workflowllm/T-Eval-main/data/plan_json_v2.json')),
        ('--model_display_name', dict(type=str, default="Llama3.1")),
        ('--resume', dict(action='store_true')),
        ('--out_name', dict(type=str, default='tmp.json')),
        ('--out_dir', dict(type=str, default="data/work_dirs/")),
        ('--model_path', dict(type=str, help="path to the model", default="facebook/opt-125m")),
        ('--test_num', dict(type=int, default=-1, help='number of samples to test, -1 means all')),
        ('--batch_size', dict(type=int, default=1)),
        ('--eval', dict(type=str, default='plan',
                        choices=['instruct', 'reason', 'plan', 'retrieve', 'review', 'understand', 'rru'])),
        ('--prompt_type', dict(type=str, default='json', choices=['json', 'str', 'python'])),
        ('--meta_template', dict(type=str, default='llama3_1')),
        ("--debug", dict(action="store_true")),
    )
    cli = argparse.ArgumentParser()
    for flag, spec in options:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

def load_dataset(dataset_path, out_dir, is_resume=False, tmp_folder_name='tmp'):
    """Load the evaluation set and, when resuming, drop samples that are already cached.

    Args:
        dataset_path (str): File readable by ``mmengine.load``; assumed to yield
            a dict keyed by sample id (ids are looked up and popped below).
        out_dir (str): Output root directory.
        is_resume (bool): If true, scan ``<out_dir>/<tmp_folder_name>`` for
            per-sample ``<id>.json`` cache files. Their ids are removed from the
            returned dataset. Cache files whose id is not in the dataset are
            deleted.
        tmp_folder_name (str): Name of the per-sample cache sub-directory.

    Returns:
        tuple: ``(dataset, tested_num, total_num)``. These are the remaining
        samples, the number found in the cache, and the original dataset size.
    """
    dataset = mmengine.load(dataset_path)
    total_num = len(dataset)
    tested_num = 0
    cache_dir = os.path.join(out_dir, tmp_folder_name)
    if is_resume and os.path.isdir(cache_dir):
        for filename in os.listdir(cache_dir):
            file_path = os.path.join(cache_dir, filename)
            if not os.path.isfile(file_path):
                continue  # never try to os.remove() a sub-directory
            # splitext keeps ids containing dots intact ("a.b.json" -> "a.b");
            # split('.')[0] would truncate them and delete valid results.
            file_id = os.path.splitext(filename)[0]
            if file_id in dataset:
                tested_num += 1
                dataset.pop(file_id)
            else:
                print(f"Warning: {filename} not in dataset, removing it from cache")
                os.remove(file_path)

    return dataset, tested_num, total_num

def split_special_tokens(text):
    """Clean a raw model generation.

    The text is cut at the first occurrence of each known end-of-turn / role
    marker, applied in the order listed. Surrounding whitespace is then
    stripped and a leading ```json fence tag is dropped. Finally, any
    backticks and whitespace left at either end are removed.
    """
    stop_markers = (
        '<eoa>',
        '<TOKENS_UNUSED_1>',
        '<|im_end|>',
        '\nuser',
        '\nassistant',
        '\nUSER',
        '[INST]',
        '<|user|>',
    )
    for marker in stop_markers:
        text, _, _ = text.partition(marker)
    text = text.strip()
    fence = '```json'
    if text.startswith(fence):
        text = text[len(fence):]
    return text.strip('`').strip()

def infer_vllm(dataset, path, out_dir, model_name, tmp_folder_name='tmp', test_num=1, batch_size=1,
               tensor_parallel_size=1):
    """Run batched vLLM generation over the first ``test_num`` dataset entries.

    Each raw generation is printed and post-processed with
    ``split_special_tokens``. The result is stored as
    ``dataset[idx]['prediction']`` and the record is dumped to
    ``<out_dir>/<tmp_folder_name>/<idx>.json``.

    Args:
        dataset (dict): Sample id -> record containing ``'origin_prompt'``
            (a dialog list rendered via ``generate_prompt``, or a raw string).
        path (str): Model path / id passed to ``vllm.LLM``.
        out_dir (str): Output root directory.
        model_name (str): Key into ``meta_template_dict`` for dialog prompts.
        tmp_folder_name (str): Per-sample cache sub-directory of ``out_dir``.
        test_num (int): Number of samples (in dataset order) to process.
        batch_size (int): Prompts per ``engine.generate`` call (values < 1 act as 1).
        tensor_parallel_size (int): Forwarded to ``vllm.LLM``.

    Returns:
        dict: Every record in the cache folder, keyed by file stem.

    Raises:
        TypeError: If an ``origin_prompt`` is neither a list nor a string.
    """
    cache_dir = os.path.join(out_dir, tmp_folder_name)
    os.makedirs(cache_dir, exist_ok=True)

    # Initialize the vLLM engine for inference
    engine = LLM(model=path, tensor_parallel_size=tensor_parallel_size)
    # The sampling configuration is identical for every batch, so build it once.
    sampling_params = SamplingParams(
        temperature=0.1,
        max_tokens=512,
        repetition_penalty=2.0,
        stop=[engine.llm_engine.tokenizer.tokenizer.eos_token],
    )

    def render(prompt):
        # Dialog lists go through the model's meta template; strings are used verbatim.
        if isinstance(prompt, list):
            return generate_prompt(model_name, prompt, meta_template_dict)
        if isinstance(prompt, str):
            return prompt
        raise TypeError(f"Unsupported origin_prompt type: {type(prompt).__name__}")

    batch_size = max(1, batch_size)
    sample_ids = list(dataset.keys())[:test_num]
    # Batch by position so that a trailing partial batch is generated as well.
    for start in tqdm(range(0, len(sample_ids), batch_size)):
        batch_ids = sample_ids[start:start + batch_size]
        # Build the prompts and run inference for this batch.
        prompts = [render(dataset[idx]['origin_prompt']) for idx in batch_ids]
        outputs = engine.generate(prompts, sampling_params=sampling_params)
        for idx, prediction in zip(batch_ids, outputs):
            generated_text = prediction.outputs[0].text
            print(generated_text)
            dataset[idx]['prediction'] = split_special_tokens(generated_text)
            mmengine.dump(dataset[idx], os.path.join(cache_dir, f'{idx}.json'))

    # Return everything cached, including samples finished in earlier (resumed) runs.
    return {
        os.path.splitext(filename)[0]: mmengine.load(os.path.join(cache_dir, filename))
        for filename in os.listdir(cache_dir)
    }


def infer_transformers_multi_gpu(
    dataset,
    path,
    out_dir,
    model_name,
    tmp_folder_name='tmp',
    test_num=1,
    batch_size=1,
    prompt_type=None,
):
    """Generate predictions with HF transformers, sharding samples across processes via accelerate.

    Each process loads a bfloat16 copy of the model onto device
    ``accelerator.process_index`` and handles its share of the first
    ``test_num`` samples. Each prediction is printed, stored as
    ``dataset[idx]['prediction']``, and the record is dumped to
    ``<out_dir>/<tmp_folder_name>/<idx>.json``.

    Args:
        dataset (dict): Sample id -> record containing ``'origin_prompt'``
            (a dialog list or a raw string).
        path (str): Model path / id for ``from_pretrained``.
        out_dir (str): Output root directory.
        model_name (str): Key into ``meta_template_dict`` for dialog prompts.
        tmp_folder_name (str): Per-sample cache sub-directory of ``out_dir``.
        test_num (int): Number of samples (in dataset order) to process.
        batch_size (int): Prompts per ``model.generate`` call (values < 1 act as 1).
        prompt_type (str, optional): ``'json'`` or ``'str'``, used for dialog
            prompts. Defaults to the script-level ``args.prompt_type``, or to
            ``'json'`` when no such global exists.

    Returns:
        dict: Every record in the cache folder, keyed by file stem.

    Raises:
        ValueError: If dialog prompts are used with an unsupported ``prompt_type``.
        TypeError: If an ``origin_prompt`` is neither a list nor a string.
    """
    if prompt_type is None:
        # Backward compatibility: this function used to read the script-level
        # ``args.prompt_type`` global directly.
        prompt_type = getattr(globals().get('args'), 'prompt_type', 'json')

    # Opening of a JSON plan appended to 'json' dialog prompts. The same text is
    # prepended to the decoded continuation so the saved prediction includes it.
    json_prefix = '[{\n    "id":0,\n    "name": "'

    def build_prompt(origin_prompt):
        # Returns (prompt_text, output_prefix) for one sample.
        if isinstance(origin_prompt, list):
            if prompt_type not in ('json', 'str'):
                raise ValueError(f"Unsupported prompt_type for dialog prompts: {prompt_type}")
            rendered = generate_prompt(model_name, origin_prompt, meta_template_dict)
            if prompt_type == 'json':
                return rendered + json_prefix, json_prefix
            return rendered, ''
        if isinstance(origin_prompt, str):
            return origin_prompt, ''
        raise TypeError(f"Unsupported origin_prompt type: {type(origin_prompt).__name__}")

    # Initialize Accelerator for multi-GPU support
    accelerator = Accelerator()
    device = accelerator.device

    use_flash_attention = True

    tokenizer = AutoTokenizer.from_pretrained(path, use_fast=False)
    # Batched decoder-only generation: pad on the left so every prompt ends at the
    # final position. Truncate on the left as well, so an over-long prompt loses
    # its beginning rather than its end (the generation cue and ``json_prefix``).
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"

    # NOTE(review): use_cache=False is stored on the model config; if generate()
    # honours it, decoding runs without a KV cache (slow) — confirm this is intended.
    model = AutoModelForCausalLM.from_pretrained(
        path,
        torch_dtype=torch.bfloat16,
        use_cache=False,
        attn_implementation="flash_attention_2" if use_flash_attention else None,
        device_map={"": accelerator.process_index},
    )
    model.config.pretraining_tp = 1
    model.eval()

    cache_dir = os.path.join(out_dir, tmp_folder_name)
    os.makedirs(cache_dir, exist_ok=True)

    # Token sequences passed to generate() as bad_words_ids.
    bad_words = ["Thought", "Code", 'python']
    bad_words_ids = [tokenizer.encode(word, add_special_tokens=False) for word in bad_words]

    batch_size = max(1, batch_size)
    sample_ids = list(dataset.keys())[:test_num]

    # Synchronize processes before starting inference
    accelerator.wait_for_everyone()

    # Each process receives its own share of ``sample_ids``.
    with accelerator.split_between_processes(sample_ids) as shard:
        shard = list(shard)
        progress_bar = tqdm(
            range(0, len(shard), batch_size),
            desc=f"Process {accelerator.process_index}",
            leave=False,
            disable=not accelerator.is_local_main_process,
        )

        for start in progress_bar:
            batch_ids = shard[start:start + batch_size]
            rendered = [build_prompt(dataset[idx]['origin_prompt']) for idx in batch_ids]

            tokenized_inputs = tokenizer(
                [prompt_text for prompt_text, _ in rendered],
                return_tensors='pt',
                padding=True,
                truncation=True,
                max_length=2048,  # Adjust as needed
                add_special_tokens=False,
            )
            tokenized_inputs = {key: value.to(device) for key, value in tokenized_inputs.items()}

            with torch.inference_mode():
                # NOTE(review): do_sample is not passed, so whether temperature
                # takes effect depends on model.generation_config — confirm.
                outputs = model.generate(
                    input_ids=tokenized_inputs['input_ids'],
                    attention_mask=tokenized_inputs['attention_mask'],
                    max_new_tokens=512,  # Corresponds to max_tokens in vLLM
                    temperature=0.05,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1,
                    bad_words_ids=bad_words_ids,
                )

            # All rows share the padded prompt length, so new tokens start at the
            # same column for the whole batch.
            prompt_len = tokenized_inputs['input_ids'].shape[1]
            for row, idx in enumerate(batch_ids):
                output_prefix = rendered[row][1]
                generated_text = output_prefix + tokenizer.decode(
                    outputs[row][prompt_len:], skip_special_tokens=True
                )
                print(generated_text)
                dataset[idx]['prediction'] = generated_text
                mmengine.dump(dataset[idx], os.path.join(cache_dir, f'{idx}.json'))

    # Wait until every process has written its shard. Otherwise a fast process
    # could read the cache below before slower ones finish and return partial results.
    accelerator.wait_for_everyone()

    # Load cached results (including samples finished in earlier, resumed runs).
    return {
        os.path.splitext(filename)[0]: mmengine.load(os.path.join(cache_dir, filename))
        for filename in os.listdir(cache_dir)
    }


def evaluate_predictions(args, output_file_path):
    """Score ``output_file_path`` with the T-Eval evaluator selected by ``args.eval``.

    The scores are merged into ``<out_dir>/<display_name>_<test_num>.json``
    under the key ``<eval>_<prompt_type>``. Entries already stored in that
    file for other keys are kept.

    Args:
        args (argparse.Namespace): Parsed CLI options (see ``parse_args``).
        output_file_path (str): File containing the predictions to evaluate.

    Raises:
        NotImplementedError: For Chinese (``_zh``) datasets, which are not supported yet.
    """
    # Fall back to a model identifier when no display name is given. ``parse_args``
    # defines no ``--model_type`` option, so it is only used when present;
    # otherwise the meta template name is used.
    model_display_name = (
        args.model_display_name
        or getattr(args, 'model_type', None)
        or args.meta_template
    )

    # Ensure output directory exists
    os.makedirs(args.out_dir, exist_ok=True)

    # Mapping evaluation types to evaluators
    eval_mapping = dict(
        instruct="InstructEvaluator",
        plan="PlanningEvaluator",
        review="ReviewEvaluator",
        reason="ReasonRetrieveUnderstandEvaluator",
        retrieve="ReasonRetrieveUnderstandEvaluator",
        understand="ReasonRetrieveUnderstandEvaluator",
        rru="ReasonRetrieveUnderstandEvaluator",
    )

    if "_zh" in args.dataset_path:
        # TODO: Chinese datasets would need a Chinese BERTScore model (previously
        # sketched as "thenlper/gte-large-zh") and a "_zh.json" results file.
        raise NotImplementedError('Not implemented!')

    bert_score_model = "/Pretrained_Language_Models/all-mpnet-base-v2"
    json_path = os.path.join(args.out_dir, model_display_name + '_' + str(args.test_num) + '.json')

    # Get the appropriate evaluator class
    evaluator_class = getattr(evaluator_factory, eval_mapping[args.eval])
    evaluator = evaluator_class(output_file_path, default_prompt_type=args.prompt_type,
                                eval_type=args.eval, bert_score_model=bert_score_model)

    # Load previous results (other eval/prompt-type combinations) if available
    results = mmengine.load(json_path) if os.path.exists(json_path) else dict()

    # Perform evaluation
    eval_results = evaluator.evaluate()
    print(eval_results)

    # Save evaluation results
    results[args.eval + '_' + args.prompt_type] = eval_results
    print(f"Writing Evaluation Results to {json_path}")
    mmengine.dump(results, json_path)


if __name__ == '__main__':
    args = parse_args()
    os.makedirs(args.out_dir, exist_ok=True)
    # Per-sample predictions are cached in <out_dir>/<out_name without extension>/.
    tmp_folder_name = os.path.splitext(args.out_name)[0]
    cache_dir = os.path.join(args.out_dir, tmp_folder_name)
    os.makedirs(cache_dir, exist_ok=True)

    output_file_path = os.path.join(args.out_dir, args.out_name)
    print(output_file_path)

    # Inference only runs when the aggregated prediction file does not exist yet.
    if not os.path.exists(output_file_path):
        dataset, tested_num, total_num = load_dataset(args.dataset_path, args.out_dir, args.resume,
                                                      tmp_folder_name=tmp_folder_name)
        remaining = max(total_num - tested_num, 0)
        if args.test_num == -1:
            test_num = remaining
        else:
            test_num = max(min(args.test_num - tested_num, remaining), 0)

        if test_num != 0:
            print(f"Tested {tested_num} samples, left {test_num} samples, total {total_num} samples")
            # infer_vllm(...) is an alternative single-process backend with the same inputs/outputs.
            prediction = infer_transformers_multi_gpu(dataset, args.model_path, args.out_dir, args.meta_template,
                                                      tmp_folder_name=tmp_folder_name, test_num=test_num,
                                                      batch_size=args.batch_size)
        else:
            # Nothing left to generate (e.g. a resumed run whose samples are all
            # cached): rebuild predictions from the per-sample cache so the output
            # file is still written before evaluation reads it.
            prediction = {
                os.path.splitext(filename)[0]: mmengine.load(os.path.join(cache_dir, filename))
                for filename in os.listdir(cache_dir)
            }

        # Post-process raw generations with the format_json helpers.
        for record in prediction.values():
            record['prediction'] = extract_triplets(remove_comments_manual(record['prediction']))

        # Save predictions.
        # NOTE(review): under a multi-process launcher every process reaches this
        # point and writes/evaluates the same file; consider restricting the dump
        # and evaluation to the main process.
        mmengine.dump(prediction, output_file_path)

    # Perform evaluation if requested
    if args.eval:
        evaluate_predictions(args, output_file_path)